import itertools
import arviz as az
import cmdstanpy
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
Nested numerical solving problems
Scientific knowledge often takes the form of specific relationships expressed by systems of equations. For example:
- A system of ordinary differential equations connects some state variables \(x\) with some other variables \(v\) through equations of the form \(\frac{dx}{dt} = f(x, v)\).
- An algebraic equation system says that some variables \(v\) are related so that \(f(v) = 0\).
- A differential algebraic equation system says that some state variables’ rates of change have the form \(\frac{dx}{dt} = f(x, v)\) and that the variables also satisfy some algebraic constraints \(g(x, v) = 0\).
If there is an analytic solution to the equation system, we can just include the solution in our statistical model like any other form of structural knowledge: easy! However, often we want to solve equations that are hard or impossible to solve analytically, but can be solved approximately using numerical methods.
This is tricky in the context of Hamiltonian Monte Carlo for two reasons:
- Computation: HMC requires many evaluations of the log probability density function and its gradients.
At every evaluation, the sampler needs to solve the embedded equation system and find the gradients of the solution with respect to all model parameters.
- Extra source of error: how good of an approximation is good enough?
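To make the "good enough" question concrete, here is a minimal sketch that is not part of the original analysis. It integrates \(\frac{dx}{dt} = -x\), which has the analytic solution \(x(t) = x_0 e^{-t}\), with a fixed-step explicit Euler scheme and compares the result with the exact answer. The function names `euler_solve` and `decay` are made up for illustration, and the solvers Stan provides are adaptive and far more sophisticated than this.

```python
import math


def euler_solve(f, x0, t_end, n_steps):
    """Integrate dx/dt = f(x) from t=0 to t_end with n_steps explicit Euler steps."""
    h = t_end / n_steps
    x = x0
    for _ in range(n_steps):
        x = x + h * f(x)
    return x


def decay(x):
    # dx/dt = -x has the analytic solution x(t) = x0 * exp(-t)
    return -x


exact = math.exp(-1.0)
# absolute error at t=1 for increasingly fine step sizes
errors = {n: abs(euler_solve(decay, 1.0, 1.0, n) - exact) for n in (10, 100, 1000)}
for n, err in errors.items():
    print(f"{n:5d} steps: absolute error {err:.2e}")
```

More steps give a smaller error but cost more function evaluations, and inside HMC that cost is paid at every evaluation of the log density.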
Reading:
- Timonen et al. (2022)
- Stan user guide sections: algebraic equation systems, ODE systems, DAE systems.
Example
We have some tubes containing a substrate \(S\) and some biomass \(C\) that we think approximately follow the Monod equation for microbial growth:
\[\begin{align*} \frac{dC}{dt} &= \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)}\cdot C(t) \\ \frac{dS}{dt} &= -\gamma \cdot \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)} \cdot C(t) \end{align*}\]
We measured \(C\) and \(S\) at different timepoints in some experiments and we want to try and find out \(\mu_{max}\), \(K_{S}\) and \(\gamma\) for the different strains in the tubes.
You can read more about the Monod equation in Allen and Waclaw (2019).
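As a quick illustration of what the rate term \(\frac{\mu_{max}\cdot S}{K_{S} + S}\) does, the sketch below evaluates it at a few substrate concentrations. The function name `monod_growth_rate` and the numbers are assumptions chosen for illustration, not values from the experiments.

```python
def monod_growth_rate(mu_max, k_s, s):
    """Specific growth rate mu(S) = mu_max * S / (K_S + S)."""
    return mu_max * s / (k_s + s)


mu_max, k_s = 0.2, 0.3  # illustrative values only

half = monod_growth_rate(mu_max, k_s, k_s)       # substrate equal to K_S
saturated = monod_growth_rate(mu_max, k_s, 1e6)  # substrate in large excess
starved = monod_growth_rate(mu_max, k_s, 0.0)    # no substrate left

print(half, saturated, starved)
```

The rate is half of \(\mu_{max}\) when \(S = K_S\), approaches \(\mu_{max}\) when substrate is abundant, and drops to zero when the substrate runs out.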
What we know
\(\mu_{max}, K_S, \gamma, S, C\) are non-negative.
\(S(0)\) and \(C(0)\) vary a little by tube.
\(\mu_{max}, K_S, \gamma\) vary by strain.
Measurement noise is roughly proportional to measured quantity.
Statistical model
We use two regression models to describe the measurements:
\[\begin{align*} y_C &\sim LN(\ln{\hat{C}}, \sigma_{C}) \\ y_S &\sim LN(\ln{\hat{S}}, \sigma_{S}) \end{align*}\]
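The lognormal measurement model encodes the assumption that noise is roughly proportional to the measured quantity. The following sketch, which uses only the Python standard library and hypothetical helper names, simulates measurements of a small and a large true value with the same \(\sigma\) and compares the spread of the two samples.

```python
import math
import random


def simulate_measurements(true_value, sigma, n, generator):
    """Draw n lognormal measurements whose log has mean ln(true_value) and sd sigma."""
    return [generator.lognormvariate(math.log(true_value), sigma) for _ in range(n)]


def sample_sd(values):
    """Sample standard deviation with the usual n - 1 denominator."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1))


py_rng = random.Random(12345)
sigma = 0.1  # illustrative noise level
small = simulate_measurements(0.1, sigma, 20000, py_rng)
large = simulate_measurements(10.0, sigma, 20000, py_rng)

# the true values differ by a factor of 100, and so should the spreads
ratio = sample_sd(large) / sample_sd(small)
print(f"ratio of standard deviations: {ratio:.1f}")
```

A lognormal model also never produces negative measurements, which matches the constraint that \(C\) and \(S\) are non-negative.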
To capture the variation in parameters by tube and strain we add a hierarchical regression model:
\[\begin{align*} \ln{\mu_{max}} &\sim N(a_{\mu_{max}}, \tau_{\mu_{max}}) \\ \ln{\gamma} &\sim N(a_{\gamma}, \tau_{\gamma}) \\ \ln{K_S} &\sim N(a_{K_S}, \tau_{K_S}) \end{align*}\]
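The Stan program later in this section builds each strain-level parameter as `exp(a + z * t)` with a standard normal `z`, which is the non-centred way of writing \(\ln{\theta} \sim N(a, \tau)\). The sketch below checks that equivalence by simulation. The values of `a` and `tau` are illustrative assumptions.

```python
import math
import random
import statistics

py_rng = random.Random(2024)
a, tau = -1.7, 0.2  # illustrative location and scale on the log scale

# non-centred construction: draw z ~ N(0, 1), then shift and scale
z = [py_rng.gauss(0.0, 1.0) for _ in range(50000)]
ln_theta = [a + tau * z_i for z_i in z]
theta = [math.exp(v) for v in ln_theta]

mean_ln = statistics.fmean(ln_theta)
sd_ln = statistics.stdev(ln_theta)
print(f"mean of ln(theta): {mean_ln:.3f}, sd of ln(theta): {sd_ln:.3f}")
```

Working on the log scale also guarantees that every simulated parameter value is positive.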
To get the true abundances given some parameters, we embed an ODE in the model:
\[ \hat{C}(t), \hat{S}(t) = \text{solve-monod-equation}(t, C_0, S_0, \mu_{max}, \gamma, K_S) \]
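For intuition about what the solve-monod-equation step involves, here is a rough sketch using a fixed-step fourth-order Runge-Kutta scheme in plain Python. This is an assumption made for illustration: the Stan program below uses the adaptive `ode_bdf_tol` solver instead, and the function names and parameter values here are hypothetical.

```python
def monod_rhs(state, mu_max, k_s, gamma):
    """Right-hand side of the Monod system, with state = (C, S)."""
    c, s = state
    mu = mu_max * s / (k_s + s)
    return (mu * c, -gamma * mu * c)


def rk4_step(state, h, *params):
    """Advance the state by one classic Runge-Kutta step of size h."""
    def shifted(base, slope, scale):
        return tuple(b + scale * k for b, k in zip(base, slope))

    k1 = monod_rhs(state, *params)
    k2 = monod_rhs(shifted(state, k1, h / 2), *params)
    k3 = monod_rhs(shifted(state, k2, h / 2), *params)
    k4 = monod_rhs(shifted(state, k3, h), *params)
    return tuple(
        x + h / 6 * (a + 2 * b + 2 * c + d)
        for x, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


def solve_monod(c0, s0, mu_max, k_s, gamma, t_end, n_steps=1500):
    """Return (C, S) at t_end, starting from (c0, s0) at t = 0."""
    state = (c0, s0)
    h = t_end / n_steps
    for _ in range(n_steps):
        state = rk4_step(state, h, mu_max, k_s, gamma)
    return state


c0, s0, gamma = 0.12, 1.2, 0.55  # illustrative values only
c_end, s_end = solve_monod(c0, s0, mu_max=0.18, k_s=0.27, gamma=gamma, t_end=15.0)
print(c_end, s_end)
```

Because \(\frac{dS}{dt} = -\gamma \frac{dC}{dt}\), the quantity \(S + \gamma C\) should stay constant along a solution, which gives a cheap sanity check on any numerical solver for this system.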
imports
Specify true parameters
In order to avoid doing too much annoying handling of strings we assume that all the parts of the problem have meaningful 1-indexed integer labels: for example, species 1 is biomass.
This code specifies the dimensions of our problem.
N_strain = 4
N_tube = 16
N_timepoint = 20
duration = 15
strains = [i+1 for i in range(N_strain)]
tubes = [i+1 for i in range(N_tube)]
species = [1, 2]
measurement_timepoint_ixs = [4, 7, 12, 15, 17]
timepoints = pd.Series(
    np.linspace(0.01, duration, N_timepoint),
    name="time",
    index=range(1, N_timepoint+1)
)
SEED = 12345
rng = np.random.default_rng(seed=SEED)
This code defines some true values for the parameters - we will use these to generate fake data.
true_param_values = {
    "a_mu_max": -1.7,
    "a_ks": -1.3,
    "a_gamma": -0.6,
    "t_mu_max": 0.2,
    "t_ks": 0.3,
    "t_gamma": 0.13,
    # draw from the seeded generator rng so that the fake data are reproducible
    "species_zero": [
        [
            np.exp(rng.normal(-2.1, 0.05)),
            np.exp(rng.normal(0.2, 0.05))
        ] for _ in range(N_tube)
    ],
    "sigma_y": [0.08, 0.1],
    "ln_mu_max_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_ks_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_gamma_z": rng.normal(0, 1, size=N_strain).tolist(),
}
# strain-level parameters on the natural scale: exp(a + t * z)
for var in ["mu_max", "ks", "gamma"]:
    true_param_values[var] = np.exp(
        true_param_values[f"a_{var}"]
        + true_param_values[f"t_{var}"] * np.array(true_param_values[f"ln_{var}_z"])
    ).tolist()
A bit of data transformation
This code does some handy transformations on the data using pandas, giving us a table of information about the measurements.
tube_to_strain = pd.Series(
    [
        (i % N_strain) + 1 for i in range(N_tube)  # % operator finds remainder
    ], index=tubes, name="strain"
)
measurements = (
    pd.DataFrame(
        itertools.product(tubes, measurement_timepoint_ixs, species),
        columns=["tube", "timepoint", "species"],
        index=range(1, len(tubes) * len(measurement_timepoint_ixs) * len(species) + 1)
    )
    .join(tube_to_strain, on="tube")
    .join(timepoints, on="timepoint")
)
Generating a Stan input dictionary
This code puts the data in the correct format for cmdstanpy.
stan_input_structure = {
    "N_measurement": len(measurements),
    "N_timepoint": N_timepoint,
    "N_tube": N_tube,
    "N_strain": N_strain,
    "tube": measurements["tube"].values.tolist(),
    "measurement_timepoint": measurements["timepoint"].values.tolist(),
    "measured_species": measurements["species"].values.tolist(),
    "strain": tube_to_strain.values.tolist(),
    "timepoint_time": timepoints.values.tolist(),
}
This code defines some prior distributions for the model’s parameters.
priors = {
    # parameters that can be negative: [location, scale] of a normal prior
    "prior_a_mu_max": [-1.8, 0.2],
    "prior_a_ks": [-1.3, 0.1],
    "prior_a_gamma": [-0.5, 0.1],
    # parameters that are non-negative: [location, scale] of a lognormal prior
    "prior_t_mu_max": [-1.4, 0.1],
    "prior_t_ks": [-1.2, 0.1],
    "prior_t_gamma": [-2, 0.1],
    "prior_species_zero": [[[-2.1, 0.1], [0.2, 0.1]]] * N_tube,
    "prior_sigma_y": [[-2.3, 0.15], [-2.3, 0.15]],
}
The next bit of code lets us configure Stan’s interface to the Sundials ODE solver.
ode_solver_configuration = {
    "abs_tol": 1e-7,
    "rel_tol": 1e-7,
    "max_num_steps": int(1e7)
}
Now we can put all the inputs together.
stan_input_common = stan_input_structure | priors | ode_solver_configuration
Load the model
This code loads the Stan program at monod.stan as a CmdStanModel object and compiles it using cmdstan’s compiler.
model = cmdstanpy.CmdStanModel(stan_file="../src/stan/monod.stan")
print(model.code())
functions {
  real get_mu_at_t(real mu_max, real ks, real S_at_t) {
    return (mu_max * S_at_t) / (ks + S_at_t);
  }
  vector ddt(real t, vector species, real mu_max, real ks, real gamma) {
    real mu_at_t = get_mu_at_t(mu_max, ks, species[2]);
    vector[2] out;
    out[1] = mu_at_t * species[1];
    out[2] = -gamma * mu_at_t * species[1];
    return out;
  }
}
data {
  int<lower=1> N_measurement;
  int<lower=1> N_timepoint;
  int<lower=1> N_tube;
  int<lower=1> N_strain;
  array[N_measurement] int<lower=1, upper=N_tube> tube;
  array[N_measurement] int<lower=1, upper=N_timepoint> measurement_timepoint;
  array[N_measurement] int<lower=1, upper=2> measured_species;
  vector<lower=0>[N_measurement] y;
  array[N_tube] int<lower=1, upper=N_strain> strain;
  array[N_timepoint] real<lower=0> timepoint_time;
  array[N_tube, 2] vector[2] prior_species_zero;
  array[2] vector[2] prior_sigma_y;
  vector[2] prior_a_mu_max;
  vector[2] prior_a_ks;
  vector[2] prior_a_gamma;
  vector[2] prior_t_mu_max;
  vector[2] prior_t_gamma;
  vector[2] prior_t_ks;
  real<lower=0> abs_tol;
  real<lower=0> rel_tol;
  int<lower=1> max_num_steps;
  int<lower=0, upper=1> likelihood;
}
parameters {
  vector[N_strain] ln_mu_max_z;
  vector[N_strain] ln_ks_z;
  vector[N_strain] ln_gamma_z;
  real a_mu_max;
  real a_ks;
  real a_gamma;
  real<lower=0> t_mu_max;
  real<lower=0> t_ks;
  real<lower=0> t_gamma;
  array[N_tube] vector<lower=0>[2] species_zero;
  vector<lower=0>[2] sigma_y;
}
transformed parameters {
  vector[N_strain] mu_max = exp(a_mu_max + ln_mu_max_z * t_mu_max);
  vector[N_strain] ks = exp(a_ks + ln_ks_z * t_ks);
  vector[N_strain] gamma = exp(a_gamma + ln_gamma_z * t_gamma);
  array[N_tube, N_timepoint] vector[2] abundance;
  for (tube_t in 1 : N_tube) {
    abundance[tube_t] = ode_bdf_tol(ddt, species_zero[tube_t], 0,
                                    timepoint_time,
                                    rel_tol, abs_tol, max_num_steps, // relative tolerance comes first
                                    mu_max[strain[tube_t]],
                                    ks[strain[tube_t]], gamma[strain[tube_t]]);
  }
}
model {
  // priors
  ln_mu_max_z ~ std_normal();
  ln_ks_z ~ std_normal();
  ln_gamma_z ~ std_normal();
  a_mu_max ~ normal(prior_a_mu_max[1], prior_a_mu_max[2]);
  a_ks ~ normal(prior_a_ks[1], prior_a_ks[2]);
  a_gamma ~ normal(prior_a_gamma[1], prior_a_gamma[2]);
  t_mu_max ~ lognormal(prior_t_mu_max[1], prior_t_mu_max[2]);
  t_ks ~ lognormal(prior_t_ks[1], prior_t_ks[2]);
  t_gamma ~ lognormal(prior_t_gamma[1], prior_t_gamma[2]);
  for (s in 1 : 2) {
    sigma_y[s] ~ lognormal(prior_sigma_y[s, 1], prior_sigma_y[s, 2]);
    for (t in 1 : N_tube){
      species_zero[t, s] ~ lognormal(prior_species_zero[t, s, 1],
                                     prior_species_zero[t, s, 2]);
    }
  }
  // likelihood
  if (likelihood) {
    for (m in 1 : N_measurement) {
      real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
      y[m] ~ lognormal(log(yhat), sigma_y[measured_species[m]]);
    }
  }
}
generated quantities {
  vector[N_measurement] yrep;
  vector[N_measurement] llik;
  for (m in 1 : N_measurement){
    real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
    yrep[m] = lognormal_rng(log(yhat), sigma_y[measured_species[m]]);
    llik[m] = lognormal_lpdf(y[m] | log(yhat), sigma_y[measured_species[m]]);
  }
}
Sample in fixed param mode to generate fake data
stan_input_true = stan_input_common | {
    "y": np.ones(len(measurements)).tolist(),  # dummy values as we don't need measurements yet
    "likelihood": 0  # we don't need to evaluate the likelihood
}
coords = {
    "strain": strains,
    "tube": tubes,
    "species": species,
    "timepoint": timepoints.index.values,
    "measurement": measurements.index.values
}
dims = {
    "abundance": ["tube", "timepoint", "species"],
    "mu_max": ["strain"],
    "ks": ["strain"],
    "gamma": ["strain"],
    "species_zero": ["tube", "species"],
    "y": ["measurement"],
    "yrep": ["measurement"],
    "llik": ["measurement"]
}
mcmc_true = model.sample(
    data=stan_input_true,
    iter_sampling=1,
    fixed_param=True,
    chains=1,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_true = az.from_cmdstanpy(
    mcmc_true,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
17:22:07 - cmdstanpy - INFO - CmdStan start processing
17:22:07 - cmdstanpy - INFO - CmdStan done processing.
Look at results
def plot_sim(true_abundance, fake_measurements, species_to_ax):
    """Plot true abundance timecourses and simulated measurements, one panel per species."""
    f, axes = plt.subplots(1, 2, figsize=[9, 3])
    axes[species_to_ax[1]].set_title("Species 1")
    axes[species_to_ax[2]].set_title("Species 2")
    for ax in axes:
        ax.set_xlabel("Time")
        ax.set_ylabel("Abundance")
    for (tube_i, species_i), df_i in true_abundance.groupby(["tube", "species"]):
        ax = axes[species_to_ax[species_i]]
        fm = df_i.merge(
            fake_measurements.drop("time", axis=1),
            on=["tube", "species", "timepoint"]
        )
        ax.plot(
            df_i.set_index("time")["abundance"], color="black", linewidth=0.5
        )
        ax.scatter(
            fm["time"],
            fm["simulated_measurement"],
            color="r",
            marker="x",
            label="simulated measurement"
        )
    return f, axes
species_to_ax = {1: 0, 2: 1}
true_abundance = (
    idata_true.posterior["abundance"]
    .to_dataframe()
    .droplevel(["chain", "draw"])
    .join(timepoints, on="timepoint")
    .reset_index()
)
fake_measurements = measurements.join(
    idata_true.posterior_predictive["yrep"]
    .to_series()
    .droplevel(["chain", "draw"])
    .rename("simulated_measurement")
).copy()
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)
f.savefig("img/monod_simulated_data.png")
Sample in prior mode
stan_input_prior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 0
}
mcmc_prior = model.sample(
    data=stan_input_prior,
    iter_warmup=100,
    iter_sampling=100,
    chains=1,
    refresh=1,
    save_warmup=True,
    inits=true_param_values,
    seed=SEED,
)
idata_prior = az.from_cmdstanpy(
    mcmc_prior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_prior
17:22:07 - cmdstanpy - INFO - CmdStan start processing
17:22:38 - cmdstanpy - INFO - CmdStan done processing.
17:22:38 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4:
Exception: lognormal_rng: Location parameter is nan, but must be finite! (in 'monod.stan', line 94, column 4 to column 69)
Consider re-running with show_console=True if the above output is unclear!
<xarray.Dataset> Size: 564kB
Dimensions: (chain: 1, draw: 100, ln_mu_max_z_dim_0: 4, ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16, species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
Coordinates:
  * chain (chain) int64 8B 0
  * draw (draw) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99
  * ln_mu_max_z_dim_0 (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
  * ln_ks_z_dim_0 (ln_ks_z_dim_0) int64 32B 0 1 2 3
  * ln_gamma_z_dim_0 (ln_gamma_z_dim_0) int64 32B 0 1 2 3
  * tube (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
  * species (species) int64 16B 1 2
  * sigma_y_dim_0 (sigma_y_dim_0) int64 16B 0 1
  * strain (strain) int64 32B 1 2 3 4
  * timepoint (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
Data variables: (12/15)
    ln_mu_max_z (chain, draw, ln_mu_max_z_dim_0) float64 3kB 0.312 ......
    ln_ks_z (chain, draw, ln_ks_z_dim_0) float64 3kB 0.3093 ... -0...
    ln_gamma_z (chain, draw, ln_gamma_z_dim_0) float64 3kB -0.1984 .....
    a_mu_max (chain, draw) float64 800B -1.883 -1.708 ... -1.867 -1.84
    a_ks (chain, draw) float64 800B -1.274 -1.282 ... -1.262
    a_gamma (chain, draw) float64 800B -0.7959 -0.7954 ... -0.5912
    ... ...
    species_zero (chain, draw, tube, species) float64 26kB 0.1295 ... 1...
    sigma_y (chain, draw, sigma_y_dim_0) float64 2kB 0.1207 ... 0....
    mu_max (chain, draw, strain) float64 3kB 0.1656 ... 0.1197
    ks (chain, draw, strain) float64 3kB 0.3079 ... 0.2662
    gamma (chain, draw, strain) float64 3kB 0.435 0.3668 ... 0.4503
    abundance (chain, draw, tube, timepoint, species) float64 512kB ...
Attributes:
    created_at: 2024-05-03T15:22:38.991208
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 130kB
Dimensions: (chain: 1, draw: 100, measurement: 160)
Coordinates:
  * chain (chain) int64 8B 0
  * draw (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
  * measurement (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
Data variables:
    yrep (chain, draw, measurement) float64 128kB 0.1828 1.363 ... 1.036
Attributes:
    created_at: 2024-05-03T15:22:38.996502
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 130kB
Dimensions: (chain: 1, draw: 100, measurement: 160)
Coordinates:
  * chain (chain) int64 8B 0
  * draw (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
  * measurement (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
Data variables:
    llik (chain, draw, measurement) float64 128kB 2.658 ... -2.064
Attributes:
    created_at: 2024-05-03T15:22:38.997377
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 11kB
Dimensions: (chain: 1, draw: 200)
Coordinates:
  * chain (chain) int64 8B 0
  * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 194 195 196 197 198 199
Data variables:
    lp (chain, draw) float64 2kB -54.29 -54.29 ... -49.47 -54.23
    acceptance_rate (chain, draw) float64 2kB 0.9259 0.0 0.0 ... 0.8613 0.629
    step_size (chain, draw) float64 2kB 0.03125 12.57 ... 0.0687 0.0687
    tree_depth (chain, draw) int64 2kB 7 0 0 4 8 7 7 7 ... 6 6 6 6 6 6 6 6
    n_steps (chain, draw) int64 2kB 127 1 1 15 255 ... 63 63 63 63 63
    diverging (chain, draw) bool 200B False True True ... False False
    energy (chain, draw) float64 2kB 76.46 78.45 76.5 ... 78.36 73.85
Attributes:
    created_at: 2024-05-03T15:22:38.995052
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
We can find the prior intervals for the true abundance and plot them in the graph.
prior_abundances = idata_prior.posterior["abundance"]
n_sample = 20
chains = rng.choice(prior_abundances.coords["chain"].values, n_sample)
draws = rng.choice(prior_abundances.coords["draw"].values, n_sample)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)
for ax, species_i in zip(axes, species):
    for tube_j in tubes:
        for chain, draw in zip(chains, draws):
            timeseries = prior_abundances.sel(chain=chain, draw=draw, tube=tube_j, species=species_i)
            ax.plot(
                timepoints.values,
                timeseries.values,
                alpha=0.5, color="skyblue", zorder=-1
            )
f.savefig("img/monod_priors.png")
Sample in posterior mode
stan_input_posterior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 1
}
mcmc_posterior = model.sample(
    data=stan_input_posterior,
    iter_warmup=300,
    iter_sampling=300,
    chains=4,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_posterior = az.from_cmdstanpy(
    mcmc_posterior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_posterior
17:22:39 - cmdstanpy - INFO - CmdStan start processing
17:26:32 - cmdstanpy - INFO - CmdStan done processing.
17:26:32 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: initial state[2] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: initial state[1] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Consider re-running with show_console=True if the above output is unclear!
<xarray.Dataset> Size: 7MB
Dimensions: (chain: 4, draw: 300, ln_mu_max_z_dim_0: 4, ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16, species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
Coordinates:
  * chain (chain) int64 32B 0 1 2 3
  * draw (draw) int64 2kB 0 1 2 3 4 5 ... 294 295 296 297 298 299
  * ln_mu_max_z_dim_0 (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
  * ln_ks_z_dim_0 (ln_ks_z_dim_0) int64 32B 0 1 2 3
  * ln_gamma_z_dim_0 (ln_gamma_z_dim_0) int64 32B 0 1 2 3
  * tube (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
  * species (species) int64 16B 1 2
  * sigma_y_dim_0 (sigma_y_dim_0) int64 16B 0 1
  * strain (strain) int64 32B 1 2 3 4
  * timepoint (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
Data variables: (12/15)
    ln_mu_max_z (chain, draw, ln_mu_max_z_dim_0) float64 38kB 1.18 ......
    ln_ks_z (chain, draw, ln_ks_z_dim_0) float64 38kB -0.1055 ... ...
    ln_gamma_z (chain, draw, ln_gamma_z_dim_0) float64 38kB 1.57 ... ...
    a_mu_max (chain, draw) float64 10kB -1.846 -1.721 ... -1.775
    a_ks (chain, draw) float64 10kB -1.243 -1.19 ... -1.511 -1.161
    a_gamma (chain, draw) float64 10kB -0.5575 -0.6314 ... -0.3633
    ... ...
    species_zero (chain, draw, tube, species) float64 307kB 0.1208 ... ...
    sigma_y (chain, draw, sigma_y_dim_0) float64 19kB 0.07315 ... ...
    mu_max (chain, draw, strain) float64 38kB 0.2116 ... 0.1125
    ks (chain, draw, strain) float64 38kB 0.2798 ... 0.3055
    gamma (chain, draw, strain) float64 38kB 0.7078 ... 0.7194
    abundance (chain, draw, tube, timepoint, species) float64 6MB 0....
Attributes:
    created_at: 2024-05-03T15:26:32.786838
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 2MB
Dimensions: (chain: 4, draw: 300, measurement: 160)
Coordinates:
  * chain (chain) int64 32B 0 1 2 3
  * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
  * measurement (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
Data variables:
    yrep (chain, draw, measurement) float64 2MB 0.2065 1.266 ... 1.048
Attributes:
    created_at: 2024-05-03T15:26:32.792594
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 2MB
Dimensions: (chain: 4, draw: 300, measurement: 160)
Coordinates:
  * chain (chain) int64 32B 0 1 2 3
  * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
  * measurement (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
Data variables:
    llik (chain, draw, measurement) float64 2MB 3.13 0.7924 ... -2.103
Attributes:
    created_at: 2024-05-03T15:26:32.793700
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
<xarray.Dataset> Size: 61kB
Dimensions: (chain: 4, draw: 300)
Coordinates:
  * chain (chain) int64 32B 0 1 2 3
  * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 294 295 296 297 298 299
Data variables:
    lp (chain, draw) float64 10kB 113.2 107.4 ... 94.53 103.2
    acceptance_rate (chain, draw) float64 10kB 0.7955 0.9379 ... 0.7681 0.9696
    step_size (chain, draw) float64 10kB 0.04984 0.04984 ... 0.05528
    tree_depth (chain, draw) int64 10kB 6 6 6 6 6 6 7 6 ... 6 6 6 6 6 6 6
    n_steps (chain, draw) int64 10kB 127 127 127 63 63 ... 63 63 63 63
    diverging (chain, draw) bool 1kB False False False ... False False
    energy (chain, draw) float64 10kB -85.05 -89.45 ... -76.39 -74.55
Attributes:
    created_at: 2024-05-03T15:26:32.790936
    arviz_version: 0.17.1
    inference_library: cmdstanpy
    inference_library_version: 1.2.1
Diagnostics: is the posterior ok?
First check the sample_stats group to see if there were any divergent transitions and if the lp parameter converged.
az.summary(idata_posterior.sample_stats)
/Users/tedgro/repos/biosustain/cmfa/.venv/lib/python3.12/site-packages/arviz/stats/diagnostics.py:592: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| lp | 104.557 | 5.397 | 93.506 | 113.715 | 0.265 | 0.188 | 425.0 | 650.0 | 1.000000e+00 |
| acceptance_rate | 0.934 | 0.081 | 0.782 | 1.000 | 0.002 | 0.002 | 1311.0 | 1216.0 | 1.010000e+00 |
| step_size | 0.049 | 0.004 | 0.044 | 0.055 | 0.002 | 0.002 | 4.0 | 1200.0 | 5.638503e+15 |
| tree_depth | 6.348 | 0.476 | 6.000 | 7.000 | 0.133 | 0.096 | 13.0 | 13.0 | 1.250000e+00 |
| n_steps | 96.013 | 33.048 | 63.000 | 127.000 | 8.526 | 6.150 | 16.0 | 14.0 | 1.190000e+00 |
| diverging | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 1200.0 | 1200.0 | NaN |
| energy | -78.516 | 7.231 | -91.897 | -65.401 | 0.379 | 0.269 | 373.0 | 496.0 | 1.000000e+00 |
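The very large `r_hat` for `step_size` and the `NaN` for `diverging` are probably artefacts rather than real problems: these quantities barely vary, or do not vary at all, within a chain after warmup, so the ratio in the warning above divides by a within-chain variance that is zero or nearly zero. The sketch below reproduces the effect with the classic Gelman-Rubin formula. It is an assumption-laden illustration, not ArviZ's implementation, which uses a rank-normalised split version; the function name `basic_r_hat` and the test chains are made up.

```python
import math
import random


def basic_r_hat(chains):
    """Classic (non-split, non-rank-normalised) Gelman-Rubin statistic."""
    m = len(chains)
    n = len(chains[0])
    chain_means = [sum(c) / n for c in chains]
    grand_mean = sum(chain_means) / m
    between = n / (m - 1) * sum((cm - grand_mean) ** 2 for cm in chain_means)
    within = sum(
        sum((x - cm) ** 2 for x in c) / (n - 1)
        for c, cm in zip(chains, chain_means)
    ) / m
    if within == 0.0:
        # no within-chain variation: the ratio is 0/0 or something/0
        return math.nan if between == 0.0 else math.inf
    return math.sqrt((between / within + n - 1) / n)


py_rng = random.Random(1)
# four well-mixed chains drawn from the same distribution
mixed = [[py_rng.gauss(0, 1) for _ in range(300)] for _ in range(4)]
# constant within each chain but different across chains, like an adapted step size
step_size_like = [[0.25] * 300, [0.5] * 300, [0.75] * 300, [1.0] * 300]
# identical constant in every chain, like a divergence indicator that is never true
never_diverging = [[0.0] * 300 for _ in range(4)]

print(basic_r_hat(mixed), basic_r_hat(step_size_like), basic_r_hat(never_diverging))
```

For this reason the rows worth reading in the table above are mainly `lp` and `energy`, together with the count of divergent transitions.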
Next check the parameter-by-parameter summary
az.summary(idata_posterior)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| ln_mu_max_z[0] | 0.812 | 0.482 | -0.086 | 1.733 | 0.025 | 0.018 | 375.0 | 494.0 | 1.01 |
| ln_mu_max_z[1] | 1.041 | 0.491 | 0.025 | 1.862 | 0.023 | 0.016 | 470.0 | 619.0 | 1.00 |
| ln_mu_max_z[2] | -0.269 | 0.492 | -1.170 | 0.661 | 0.025 | 0.018 | 374.0 | 514.0 | 1.01 |
| ln_mu_max_z[3] | -1.526 | 0.520 | -2.491 | -0.540 | 0.026 | 0.018 | 420.0 | 539.0 | 1.01 |
| ln_ks_z[0] | -0.244 | 0.931 | -1.968 | 1.450 | 0.026 | 0.024 | 1315.0 | 809.0 | 1.00 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| abundance[16, 18, 2] | 1.059 | 0.049 | 0.966 | 1.154 | 0.001 | 0.001 | 1593.0 | 878.0 | 1.00 |
| abundance[16, 19, 1] | 0.441 | 0.019 | 0.404 | 0.477 | 0.000 | 0.000 | 1653.0 | 1007.0 | 1.00 |
| abundance[16, 19, 2] | 1.041 | 0.050 | 0.948 | 1.138 | 0.001 | 0.001 | 1620.0 | 827.0 | 1.00 |
| abundance[16, 20, 1] | 0.472 | 0.021 | 0.428 | 0.510 | 0.001 | 0.000 | 1601.0 | 1083.0 | 1.00 |
| abundance[16, 20, 2] | 1.022 | 0.051 | 0.928 | 1.121 | 0.001 | 0.001 | 1649.0 | 847.0 | 1.00 |
704 rows × 9 columns
Show posterior intervals
posterior_abundances = idata_posterior.posterior["abundance"]
n_sample = 20
chains = rng.choice(posterior_abundances.coords["chain"].values, n_sample)
draws = rng.choice(posterior_abundances.coords["draw"].values, n_sample)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)
for ax, species_i in zip(axes, species):
    for tube_j in tubes:
        for chain, draw in zip(chains, draws):
            timeseries = posterior_abundances.sel(chain=chain, draw=draw, tube=tube_j, species=species_i)
            ax.plot(
                timepoints.values,
                timeseries.values,
                alpha=0.5, color="skyblue", zorder=-1
            )
f.savefig("img/monod_posteriors.png")
Look at the posterior
The next few cells use arviz’s plot_posterior function to plot the marginal posterior distributions for some of the model’s parameters:
f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["gamma"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["gamma"]):
    ax.axvline(true_value, color="red")
f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["mu_max"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["mu_max"]):
    ax.axvline(true_value, color="red")
f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["ks"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["ks"]):
    ax.axvline(true_value, color="red")